import torch
import torch.nn as nn

# 构造一个简单输入
# 假设 batch_size = 2，每个样本有 3 个特征，写成张量：
x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
print(x.shape)
# torch.Size([2, 3])

# 创建 LayerNorm 实例
#   我们要对“最后一个维度（3 个特征）”做归一化，
#   因此 normalized_shape = 3。
# elementwise_affine=True 表示归一化后还会做
#   缩放 γ 和 平移 β（可学习，初始 γ=1，β=0）。
ln = nn.LayerNorm(normalized_shape=3, eps=1e-5, elementwise_affine=True)

"""
3.1 计算均值 μ 和方差 σ²
    对每个样本沿着最后一维算：
    样本 0: μ₀ = (1+2+3)/3 = 2.0
        σ²₀ = [(1-2)²+(2-2)²+(3-2)²]/3 = (1+0+1)/3 = 0.6667
    样本 1: μ₁ = (4+5+6)/3 = 5.0
        σ²₁ = [(4-5)²+(5-5)²+(6-5)²]/3 = (1+0+1)/3 = 0.6667
3.2 归一化
    公式：
        y = (x - μ) / sqrt(σ² + eps)
    样本 0:
        (1-2)/√(0.6667+1e-5) ≈ -1.2247
        (2-2)/√... = 0
        (3-2)/√... ≈ 1.2247
    样本 1:
        (4-5)/√(0.6667+1e-5) ≈ -1.2247
        (5-5)/√... = 0
        (6-5)/√... ≈ 1.2247
3.3 缩放 + 平移
    当前 γ=[1,1,1]，β=[0,0,0]，所以结果不变。
"""
y = ln(x)
print(y)
# tensor([[-1.2247,  0.0000,  1.2247],
#         [-1.2247,  0.0000,  1.2247]], grad_fn=<NativeLayerNormBackward0>)

